Report: Triplet Classification

Introduction

We will quickly examine one event, with all of its data, from the Kaggle TrackML dataset. This is event number 9999, the last event in the whole dataset, chosen because we will run classification on it later and it therefore should not be in the training set of earlier events.

In [95]:
event_file = '/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000009999'
hits, particles, truth = trackml.dataset.load_event(
        event_file, parts=['hits', 'particles', 'truth'])

We will leave noise hits in for this intro:

In [96]:
pt_min = 0
vlids = [(8,2), (8,4), (8,6), (8,8),
             (13,2), (13,4), (13,6), (13,8),
             (17,2), (17,4)]
n_det_layers = len(vlids)
# Select barrel layers and assign convenient layer number [0-9]
vlid_groups = hits.groupby(['volume_id', 'layer_id'])
hits = pd.concat([vlid_groups.get_group(vlids[i]).assign(layer=i)
                  for i in range(n_det_layers)])
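The groupby/get_group/assign pattern above can be sketched on a toy table (all values here are synthetic, not detector data):

```python
import pandas as pd

# Synthetic hits: volume 99 is not in the barrel list and should be dropped
toy_hits = pd.DataFrame({
    'hit_id':    [1, 2, 3, 4],
    'volume_id': [8, 8, 13, 99],
    'layer_id':  [2, 4, 2, 2],
})
toy_vlids = [(8, 2), (8, 4), (13, 2)]
groups = toy_hits.groupby(['volume_id', 'layer_id'])
# Keep only the listed (volume, layer) pairs, tagging each with a layer index
selected = pd.concat([groups.get_group(toy_vlids[i]).assign(layer=i)
                      for i in range(len(toy_vlids))])
print(selected[['hit_id', 'layer']])  # hit 4 is gone; layers are 0, 1, 2
```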
In [97]:
# Calculate PER-HIT particle transverse momentum
pt = np.sqrt(truth.tpx**2 + truth.tpy**2)
truth = truth.assign(pt=pt)
# Calculate derived hits variables
r = np.sqrt(hits.x**2 + hits.y**2)
phi = np.arctan2(hits.y, hits.x)
# Select the data columns we need
hits = (hits[['hit_id', 'x', 'y', 'z', 'layer']]
        .assign(r=r, phi=phi)
        .merge(truth[['hit_id', 'particle_id', 'pt']], on='hit_id'))
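As a quick sanity check on the cylindrical transform just applied (toy values, not detector data):

```python
import numpy as np

# A hit at (x, y) = (3, 4) has r = 5 and phi = atan2(4, 3)
x, y = 3.0, 4.0
r = np.sqrt(x**2 + y**2)
phi = np.arctan2(y, x)

# The plots later invert this: x = r cos(phi), y = r sin(phi)
x_back, y_back = r * np.cos(phi), r * np.sin(phi)
print(r, x_back, y_back)  # 5.0 3.0 4.0
```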

We now have a hit list that has cylindrical co-ordinates, and...

In [45]:
len(np.unique(hits.particle_id))
Out[45]:
6715

unique tracks, including one "track" of all noise hits with particle ID = 0. We can visualise some of these tracks...

In [75]:
plt.figure(figsize=(10,10))
for pid in np.unique(hits.particle_id)[:20]:
    pid_hits = hits[hits['particle_id'] == pid]
    size = 1 if pid==0 else 50
    plt.scatter(pid_hits.x, pid_hits.y, s=size)

where the blue points are noise hits, and the colors denote hits from the same track.

Consider the pT distribution:

In [109]:
sns.distplot(hits[hits.pt < 3].pt)
Out[109]:
<matplotlib.axes._subplots.AxesSubplot at 0x2aab3c038d30>

That is, most of the hits are from tracks below 0.5 GeV. This may be misleading, though, because it is counting HITS, not tracks. We can instead histogram over tracks:

In [108]:
sns.distplot(hits[hits.pt < 3].groupby('particle_id')['pt'].mean())
Out[108]:
<matplotlib.axes._subplots.AxesSubplot at 0x2aab3ca7ec18>
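The hit-weighted versus track-weighted distinction is easy to see on a toy table (synthetic pT values):

```python
import pandas as pd

# One low-pT track leaving three hits, one high-pT track leaving a single hit
toy = pd.DataFrame({'particle_id': [1, 1, 1, 2],
                    'pt':          [0.3, 0.3, 0.3, 2.0]})
per_hit_mean = toy['pt'].mean()                                   # 0.725
per_track_mean = toy.groupby('particle_id')['pt'].mean().mean()   # 1.15
print(per_hit_mean, per_track_mean)
```

The per-hit view is pulled toward the low-pT track simply because it leaves more hits.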
In [171]:
low_pt_hits = hits[hits.pt < 0.3]
In [172]:
plt.figure(figsize=(10,10))
for pid in np.unique(low_pt_hits.particle_id)[10:20]:
    pid_hits = hits[hits['particle_id'] == pid]
    size = 1 if pid==0 else 50
    plt.scatter(pid_hits.x, pid_hits.y, s=size)

To make our job as hard as possible, let's try to classify a track that is low pT, AND has duplicate hits on a layer.

In [236]:
layer_count = low_pt_hits.groupby(['particle_id', 'layer']).count()
duplicate_pids = layer_count[layer_count.hit_id > 1]
unique_duplicate_pids = np.unique([row[0] for row in duplicate_pids.index])

Meet Brian

Out of the total number of tracks, there are...

In [56]:
len(unique_duplicate_pids)

Low pT tracks with duplicate hits. These look like...

In [262]:
plt.figure(figsize=(10,10))
for pid in unique_duplicate_pids[-10:]:
    pid_hits = hits[hits['particle_id'] == pid]
    size = 1 if pid==0 else 50
    plt.scatter(pid_hits.x, pid_hits.y, s=size)

Look at that beautiful red track with multiple duplicates and a huge curvature. It has particle ID #846680715692089346 and event ID #9999, but let's call it Brian. Our task will be to classify all tracks correctly, but focus on Brian as a tough case.

In [2]:
brian = 846680715692089346

Graph Construction

Introduction

We will try multiple methods of constructing a graph from the hit data. For these constructions, we will work with no pT cut or duplicate restriction - all real hits are fair game. However, we will remove noise.

In [10]:
# Geometric and physics cuts
pt_min = 0.
phi_slope_max = .001
z0_max = 200

# Graph features and scale
feature_names = ['r', 'phi', 'z']
n_phi_sections = 8
feature_scale = np.array([1000., np.pi / n_phi_sections, 1000.])
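To see what this scale does, a hit's (r, phi, z) is mapped to roughly unit-range features (toy hit values):

```python
import numpy as np

n_phi_sections = 8
feature_scale = np.array([1000., np.pi / n_phi_sections, 1000.])

# r = 500 mm, phi at a section edge, z = -300 mm
hit = np.array([500., np.pi / n_phi_sections, -300.])
scaled = hit / feature_scale
print(scaled)  # [ 0.5  1.  -0.3]
```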

Heuristic Construction

We load the set of TrackML data, and run the constructor to get training and testing data. Training from the front, testing from the back. For visualisation purposes, we will just look at one test example, event #9999.

In [28]:
def calc_eta(r, z):
    theta = np.arctan2(r, z)
    return -1. * np.log(np.tan(theta / 2.))

def calc_dphi(phi1, phi2):
    """Computes phi2-phi1 given in range [-pi,pi]"""
    dphi = phi2 - phi1
    dphi[dphi > np.pi] -= 2*np.pi
    dphi[dphi < -np.pi] += 2*np.pi
    return dphi

def split_detector_sections(hits, phi_edges, eta_edges):
    """Split hits according to provided phi and eta boundaries."""
    hits_sections = []
    # Loop over sections
    for i in range(len(phi_edges) - 1):
        phi_min, phi_max = phi_edges[i], phi_edges[i+1]
        # Select hits in this phi section
        phi_hits = hits[(hits.phi > phi_min) & (hits.phi < phi_max)]
        # Center these hits on phi=0
        centered_phi = phi_hits.phi - (phi_min + phi_max) / 2
        phi_hits = phi_hits.assign(phi=centered_phi, phi_section=i)
        for j in range(len(eta_edges) - 1):
            eta_min, eta_max = eta_edges[j], eta_edges[j+1]
            # Select hits in this eta section
            eta = calc_eta(phi_hits.r, phi_hits.z)
            sec_hits = phi_hits[(eta > eta_min) & (eta < eta_max)]
            hits_sections.append(sec_hits.assign(eta_section=j))
    return hits_sections

def select_hits(hits, truth, particles, pt_min=0):
    # Barrel volume and layer ids
    vlids = [(8,2), (8,4), (8,6), (8,8),
             (13,2), (13,4), (13,6), (13,8),
             (17,2), (17,4)]
    n_det_layers = len(vlids)
    # Select barrel layers and assign convenient layer number [0-9]
    vlid_groups = hits.groupby(['volume_id', 'layer_id'])
    hits = pd.concat([vlid_groups.get_group(vlids[i]).assign(layer=i)
                      for i in range(n_det_layers)])
    # Calculate PER-HIT particle transverse momentum
    pt = np.sqrt(truth.tpx**2 + truth.tpy**2)
    truth = truth.assign(pt=pt)
    # Calculate derived hits variables
    r = np.sqrt(hits.x**2 + hits.y**2)
    phi = np.arctan2(hits.y, hits.x)
    # Select the data columns we need
    hits = (hits[['hit_id', 'x', 'y', 'z', 'layer']]
            .assign(r=r, phi=phi)
            .merge(truth[['hit_id', 'particle_id', 'pt']], on='hit_id'))
    # Remove noise
    hits = hits[hits.particle_id != 0]
    return hits    

def select_segments(hits1, hits2, phi_slope_max, z0_max):
    """
    Construct a list of selected segments from the pairings
    between hits1 and hits2, filtered with the specified
    phi slope and z0 criteria.
    Returns: pd DataFrame of (index_1, index_2), corresponding to the
    DataFrame hit label-indices in hits1 and hits2, respectively.
    """
    # Start with all possible pairs of hits
    keys = ['evtid', 'r', 'phi', 'z', 'particle_id', 'hit_id']
    hit_pairs = hits1[keys].reset_index().merge(
        hits2[keys].reset_index(), on='evtid', suffixes=('_1', '_2'))
    # Compute line through the points
    dphi = calc_dphi(hit_pairs.phi_1, hit_pairs.phi_2)
    dz = hit_pairs.z_2 - hit_pairs.z_1
    dr = hit_pairs.r_2 - hit_pairs.r_1
    phi_slope = dphi / dr
    z0 = hit_pairs.z_1 - hit_pairs.r_1 * dz / dr
    # Filter segments according to criteria
    good_seg_mask = (phi_slope.abs() < phi_slope_max) & (z0.abs() < z0_max)
    return hit_pairs[['index_1', 'index_2']][good_seg_mask]

def construct_graph(hits, layer_pairs, phi_slope_max, z0_max,
                    feature_names, feature_scale):
    layer_groups = hits.groupby('layer')
    segments = []
    for (layer1, layer2) in layer_pairs:
        # Find and join all hit pairs
        try:
            hits1 = layer_groups.get_group(layer1)
            hits2 = layer_groups.get_group(layer2)
        # If an event has no hits on a layer, we get a KeyError.
        # In that case we just skip to the next layer pair
        except KeyError as e:
            logging.info('skipping empty layer: %s' % e)
            continue
        # Construct the segments
        segments.append(select_segments(hits1, hits2, phi_slope_max, z0_max))
    # Combine segments from all layer pairs
    segments = pd.concat(segments)
    
    
    X = (hits[feature_names].values / feature_scale).astype(np.float32)
    pid = (hits['particle_id'].values).astype(np.int64)
    I = (hits['hit_id'].values).astype(np.int64)
    n_edges = len(segments)
    n_hits = len(hits)
    
    pid1 = hits.particle_id.loc[segments.index_1].values
    pid2 = hits.particle_id.loc[segments.index_2].values
    y = np.zeros(n_edges, dtype=np.float32)
    y[:] = (pid1 == pid2)
    
    hit_idx = pd.Series(np.arange(n_hits), index=hits.index)
    seg_start = hit_idx.loc[segments.index_1].values
    seg_end = hit_idx.loc[segments.index_2].values
    
    e = np.vstack([seg_start, seg_end])
    
    data = Data(x=torch.from_numpy(X).float(),
                edge_index=torch.from_numpy(e),
                y=torch.from_numpy(y),
                I=torch.from_numpy(I),
                pid=torch.from_numpy(pid))
    
    return data
    

def build_event(event_file, pt_min, phi_slope_max, z0_max, feature_names, feature_scale, n_phi_sections=1, n_eta_sections=1):
    hits, particles, truth = trackml.dataset.load_event(
        event_file, parts=['hits', 'particles', 'truth'])
    hits = select_hits(hits, truth, particles, pt_min=pt_min).assign(evtid=int(event_file[-9:]))
    
    phi_range, eta_range = [-np.pi, np.pi], [-5, 5]
    phi_edges = np.linspace(*phi_range, num=n_phi_sections+1)
    eta_edges = np.linspace(*eta_range, num=n_eta_sections+1)
    hits_sections = split_detector_sections(hits, phi_edges, eta_edges)
    
    # Define adjacent layers
    n_det_layers = 10
    l = np.arange(n_det_layers)
    layer_pairs = np.stack([l[:-1], l[1:]], axis=1)
    
    graphs_all = [construct_graph(section_hits, layer_pairs=layer_pairs,
                              phi_slope_max=phi_slope_max, z0_max=z0_max,
                              feature_names=feature_names,
                              feature_scale=feature_scale)
                              for section_hits in hits_sections]
    
    return graphs_all

def prepare_event(event_file, pt_min, phi_slope_max, z0_max, feature_names, feature_scale, n_phi_sections=1, iter=None, num_samples=None, out=None):
    
    graphs_all = build_event(event_file, pt_min, phi_slope_max, z0_max, feature_names, feature_scale, n_phi_sections)

    if iter is not None and num_samples is not None:
        out.update(progress(iter, num_samples))    

    return graphs_all
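The phi wraparound in calc_dphi and the z0 intercept used in select_segments can be checked on toy values (the helper is re-defined here so the sketch stands alone):

```python
import numpy as np

def calc_dphi(phi1, phi2):
    """Computes phi2 - phi1, wrapped into [-pi, pi]."""
    dphi = phi2 - phi1
    dphi[dphi > np.pi] -= 2 * np.pi
    dphi[dphi < -np.pi] += 2 * np.pi
    return dphi

# Two hits just either side of the phi = +/-pi boundary: the naive difference
# is close to -2*pi, but the wrapped difference is small
dphi = calc_dphi(np.array([np.pi - 0.05]), np.array([-np.pi + 0.05]))
print(dphi)  # [0.1]

# z0 is the z-intercept (at r = 0) of the line through two hits (r1, z1), (r2, z2)
r1, z1, r2, z2 = 100.0, 50.0, 200.0, 100.0
z0 = z1 - r1 * (z2 - z1) / (r2 - r1)
print(z0)  # 0.0: the segment extrapolates back to the origin, so |z0| < z0_max
```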
In [29]:
HC_input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/"
all_events = os.listdir(HC_input_dir)
all_events = [HC_input_dir + event[:14] for event in all_events]
all_events = list(set(all_events))
all_events.sort()
In [30]:
train_size, test_size = 1, 1
out = display(progress(0, train_size), display_id=True)
HC_train_dataset = [prepare_event(event_file, pt_min, phi_slope_max, z0_max, feature_names, feature_scale, n_phi_sections, iter, train_size, out) for (event_file, iter) in zip(all_events[:train_size], range(train_size))]
HC_train_dataset = [datapoint for dataset in HC_train_dataset for datapoint in dataset]
out = display(progress(0, test_size), display_id=True)
HC_test_dataset = [prepare_event(event_file, pt_min, phi_slope_max, z0_max, feature_names, feature_scale, n_phi_sections, iter, test_size, out) for (event_file, iter) in zip(all_events[-test_size:], range(test_size))]
HC_test_dataset = [datapoint for dataset in HC_test_dataset for datapoint in dataset]
HC_train_loader = DataLoader(HC_train_dataset, batch_size=2, shuffle=True)
HC_test_loader = DataLoader(HC_test_dataset, batch_size=2, shuffle=True)

Finding Brian:

In [19]:
g1 = HC_test_loader.dataset[7]
g2 = HC_test_loader.dataset[6]
X = g1.x.numpy() * feature_scale
hits = np.vstack([X.T, g1.I.numpy(), g1.pid.numpy()]).T
sum(hits[:,4] == brian)
In [76]:
np.where(np.unique(hits[:,4])==brian)
Out[76]:
(array([1185]),)
In [77]:
plt.figure(figsize=(10,10))
for pid in np.unique(hits[:,4])[1170:1180]:
    pid_hits = hits[hits[:,4] == pid]
    x = pid_hits[:,0] * np.cos(pid_hits[:,1])
    y = pid_hits[:,0] * np.sin(pid_hits[:,1])
    plt.scatter(x, y, s=50)

where our Brian is the grey track above.

In [4]:
def plot_brian(event, feature_scale, brian, phi_min = -np.pi, phi_max = np.pi, r_min = 0, r_max = 1200):
    X = event.x.numpy() * feature_scale
    X_filter = (X[:,1] > phi_min) & (X[:,1] < phi_max) & (X[:,0] >= r_min) & (X[:,0] <= r_max)
    brian_filter = event.pid.numpy() == brian
    
    x = X[:,0] * np.cos(X[:,1])
    y = X[:,0] * np.sin(X[:,1])
    
    plt.figure(figsize=(10,10))
    e = event.edge_index.numpy()

    brian_edges = e[:,(X_filter[e[0,:]]) & (X_filter[e[1,:]]) & (brian_filter[e[0,:]]) & (brian_filter[e[1,:]])]
    print(brian_edges)
    e = e[:,(X_filter[e[0,:]]) & (X_filter[e[1,:]])]
    X_ind = np.arange(len(X))
    e_filter = (np.isin(X_ind, e[0,:])) | (np.isin(X_ind, e[1,:]))
        
    plt.plot([x[e[0,:]], x[e[1,:]]], [y[e[0,:]], y[e[1,:]]], c='b', alpha=0.01)    
    plt.plot([x[brian_edges[0,:]], x[brian_edges[1,:]]], [y[brian_edges[0,:]], y[brian_edges[1,:]]], c='r', alpha=0.9)
    
    plt.scatter(x[X_filter & e_filter & ~brian_filter], y[X_filter & e_filter & ~brian_filter], c='k', alpha=0.01)
    plt.scatter(x[X_filter & brian_filter], y[X_filter & brian_filter], c='r', s=50, alpha=0.9)
    
In [81]:
%%time
plot_brian(g1, feature_scale, brian, r_min = 80, r_max=650)
[[2323 2323 2960 2960 2964 2964 3549]
 [2960 2964 3545 3549 3545 3549 4059]]
CPU times: user 8.97 s, sys: 165 ms, total: 9.13 s
Wall time: 9.13 s
In [51]:
%%time
plot_brian(g2, feature_scale, brian, r_min = 0, r_max=1400)
[]
CPU times: user 28.9 s, sys: 651 ms, total: 29.6 s
Wall time: 29.6 s

Clearly, the heuristic construction didn't nail Brian. Even before classification, he is only half constructed, split between graphs and with some hits too far from each other to connect with doublets.

Embedding Construction

In [5]:
def reset_hit_indices(X, e, I, pid):
    U = list(set(e[0,:]) | set(e[1,:]))
    newX = X[U]
    newI = I[U]
    newpid = pid[U]
    Reverse_U = np.zeros(len(X), dtype=int)
    Reverse_U[U] = np.arange(len(U))
    newE = np.zeros((2,e.shape[1]), dtype=int)
    newE[0,:] = Reverse_U[e[0,:]]
    newE[1,:] = Reverse_U[e[1,:]]
       
    return newX, newE, newI, newpid

def augment_graph(e, y):
    
    low_edge_count = [k for k,v in Counter(np.hstack([e[0,:], e[1,:]])).items() if v < 4]
    
    new_edges = []
    for i, hit in enumerate(low_edge_count):
        random_edge = random.choice(low_edge_count[:i]+low_edge_count[i+1:])
        new_edges.append([hit, random_edge])
    
    new_edges = np.array(new_edges).T
    e = np.hstack((e, new_edges))
    y = np.hstack([y, np.zeros(new_edges.shape[1])])
    
    return e, y

def construct_AE_graph(X, e, y, I, pid, norm_phi_min, delta, n_phi_sections, augmented):
    """Construct one graph (e.g. from one event).

    Masks out the edges and hits falling within [phi_min, phi_min + delta_phi].
    """
    
    # Mask out phi segment edges and truths, leaving full hit list
    seg_hits = (X[:,1] >= norm_phi_min) & (X[:,1] < (norm_phi_min + delta))
    seg_edges = seg_hits[e[1,:]] # Whether to filter by in or out end may impact training, explore this further!
    e_sec = e[:, seg_edges]
    y_sec = y[seg_edges]
    
    # Prepare the features data types    
    y_sec = y_sec.astype(np.float32)
    
    # Option to augment graph with random connections
    if augmented:
        e_sec, y_sec = augment_graph(e_sec, y_sec)
    
    # Reset hit indices to avoid unused hits
    X, e_sec, I, pid = reset_hit_indices(X, e_sec, I, pid)

    # Center phi at 0
    X[:,1] = X[:,1] - norm_phi_min - (delta/2)
    # Handle case of crossing the boundary
    X[X[:,1] < (-n_phi_sections), 1] += 2*n_phi_sections
    X[X[:,1] > (n_phi_sections), 1] -= 2*n_phi_sections
    
    
    graph = Data(x=torch.from_numpy(X).float(),
                 edge_index=torch.from_numpy(e_sec),
                 y=torch.from_numpy(y_sec),
                 I=torch.from_numpy(I),
                 pid=torch.from_numpy(pid))
    
    # Return a tuple of the results
    return graph

def prepare_AE_event(event, feature_names, feature_scale, pt_min, n_phi_sections=1, iter=None, num_samples=None, out=None, augmented=False):
    
    event = np.load(event, allow_pickle=True)
    
    hits, truth, e, scores = event['hits'], event['truth'].reset_index(drop=True), event['neighbors'], event['scores']
    
    # Calculate derived hits variables
    r = np.sqrt(hits.x**2 + hits.y**2)
    phi = np.arctan2(hits.y, hits.x)
    # Select the data columns we need
    hits = hits.assign(r=r, phi=phi)
    
    I = hits['hit_id'].to_numpy().astype(np.float32)
    X = (hits[feature_names].values / feature_scale).astype(np.float32)
    truth = truth.to_numpy()
    pt_mask = truth[:,9] > pt_min
    e = e[:,pt_mask[e[0,:]] & pt_mask[e[1,:]]]
    
    # Remove duplicate edges
    r_mask = X[e[0,:],0] > X[e[1,:],0]
    e[0,r_mask], e[1,r_mask] = e[1,r_mask], e[0,r_mask]
    e = np.array(list(set(list(zip(e[0,:], e[1,:]))))).T
    
    pid = truth[:,1]
    y = (pid[e[0,:]] == pid[e[1,:]]).astype(np.float32)
    
    data = []
    
    delta = 2
    
    for i in range(n_phi_sections):
        
        norm_phi_min = -n_phi_sections + (i*delta)
        graph = construct_AE_graph(X, e, y, I, pid,  norm_phi_min, delta, n_phi_sections, augmented=augmented)
        data.append(graph)
    
    
    if iter is not None and num_samples is not None:
        out.update(progress(iter, num_samples))    
    
    return data
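The relabelling done by reset_hit_indices is easiest to see on a tiny synthetic graph (the logic is copied from the function above so the sketch runs standalone):

```python
import numpy as np

def reset_hit_indices(X, e, I, pid):
    # Keep only hits that appear in some edge, relabelled 0..n-1
    U = list(set(e[0, :]) | set(e[1, :]))
    reverse = np.zeros(len(X), dtype=int)
    reverse[U] = np.arange(len(U))
    return X[U], reverse[e], I[U], pid[U]

X = np.arange(10, dtype=float).reshape(5, 2)  # 5 hits, 2 features each
I = np.arange(100, 105)                       # hit ids
pid = np.array([7, 7, 8, 8, 9])
e = np.array([[0, 3],
              [3, 4]])                        # hits 1 and 2 are unused

newX, newE, newI, newpid = reset_hit_indices(X, e, I, pid)
print(newX.shape)  # (3, 2): the two unused hits are dropped
```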

We will load graphs constructed using an Adjacent Embedding (AE) model. That is, the embedding is trained on hits from adjacent layers, but it is not restricted to constructing edges only between adjacent layers. Let's examine what it gives us.
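One way to quantify how often an edge list skips layers, assuming a per-hit layer array like the one built earlier (all values here are synthetic):

```python
import numpy as np

layer = np.array([0, 1, 2, 4, 5])   # hypothetical per-hit layer numbers
e = np.array([[0, 1, 2],
              [1, 2, 3]])           # the last edge jumps from layer 2 to layer 4

gap = np.abs(layer[e[0, :]] - layer[e[1, :]])
print(np.bincount(gap))             # edges per layer gap: [0 2 1]
print(int(np.sum(gap > 1)))         # 1 layer-skipping edge
```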

In [7]:
AE_input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/processed_sparse/adjacent/graphs"
all_events = os.listdir(AE_input_dir)
all_events = [os.path.join(AE_input_dir, event) for event in all_events]
all_events.sort()
In [7]:
%%time
train_size, test_size = 1, 1
out = display(progress(0, train_size), display_id=True)
AE_train_dataset = [prepare_AE_event(event, feature_names, feature_scale, pt_min, n_phi_sections, iter, train_size, out) for (iter, event) in enumerate(all_events[:train_size])]
AE_train_dataset = [datapoint for dataset in AE_train_dataset for datapoint in dataset]
out = display(progress(0, test_size), display_id=True)
AE_test_dataset = [prepare_AE_event(event, feature_names, feature_scale, pt_min, n_phi_sections, iter, test_size, out) for (iter, event) in enumerate(all_events[-test_size:])]
AE_test_dataset = [datapoint for dataset in AE_test_dataset for datapoint in dataset]
AE_train_loader = DataLoader(AE_train_dataset, batch_size=2, shuffle=True)
AE_test_loader = DataLoader(AE_test_dataset, batch_size=2, shuffle=True)
CPU times: user 1.35 s, sys: 452 ms, total: 1.81 s
Wall time: 1.59 s
In [7]:
g1 = AE_test_loader.dataset[7]
g2 = AE_test_loader.dataset[6]
X = g1.x.numpy() * feature_scale
hits = np.vstack([X.T, g1.I.numpy(), g1.pid.numpy()]).T
sum(hits[:,4] == brian)
Out[7]:
10
In [8]:
np.where(np.unique(hits[:,4])==brian)
Out[8]:
(array([1335]),)

Where is Brian:

In [9]:
len(np.unique(hits[:,4]))
Out[9]:
1345
In [10]:
plt.figure(figsize=(10,10))
for pid in np.unique(hits[:,4])[1330:1340]:
    pid_hits = hits[hits[:,4] == pid]
    x = pid_hits[:,0] * np.cos(pid_hits[:,1])
    y = pid_hits[:,0] * np.sin(pid_hits[:,1])
    plt.scatter(x, y, s=50)

Our friend should be recognisable as the brown track. The entire track is contained in this graph section, which is encouraging - it means the embedding considered the edges connecting the two furthest hits likely, even though those hits fall outside the naive division of the event. A visualisation will help:

In [11]:
%%time
plot_brian(g1, feature_scale, brian)
[[3776 1333 3772 5814 5541 3776 3772 2376 5545  660 1333 1333]
 [5541 3776 5545 1333  660 5545 5541 2376  660 2376 3772 1333]]
CPU times: user 7.45 s, sys: 166 ms, total: 7.62 s
Wall time: 7.62 s

Observe that Brian was not fully connected in this graph because we (arbitrarily) chose to keep only edges pointing into the phi section, not out of it. We can revisit this decision later... For now, look at the other graph containing Brian.

In [12]:
%%time
plot_brian(g2, feature_scale, brian)
[[2654 2654]
 [3845 3839]]
CPU times: user 8.66 s, sys: 252 ms, total: 8.91 s
Wall time: 8.91 s

A lot to unpack here: (a) there are some self-edges, which I hope will be removed by the recent push, and (b) there are also doublets between hits on the same layer.
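Until such a fix lands, self-edges (and, given per-hit layer numbers, same-layer doublets) can be filtered by hand; a minimal sketch on synthetic arrays (the layer array here is hypothetical):

```python
import numpy as np

e = np.array([[0, 1, 2, 2],
              [1, 1, 3, 2]])        # columns 1 and 3 are self-edges
layer = np.array([0, 0, 1, 2])      # hypothetical per-hit layer numbers

keep = e[0, :] != e[1, :]                    # drop self-edges
keep &= layer[e[0, :]] != layer[e[1, :]]     # drop same-layer doublets
e = e[:, keep]
print(e)  # only the (2, 3) edge survives
```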

Edge Distributions

In [34]:
HC_data = HC_test_loader.dataset[0]
HC_e, HC_X = HC_data['edge_index'].numpy(), HC_data['x'].numpy()*feature_scale
In [35]:
AE_data = AE_test_loader.dataset[0]
AE_e, AE_X = AE_data['edge_index'].numpy(), AE_data['x'].numpy()*feature_scale
In [36]:
plt.figure(figsize=(10,10))
ax = sns.distplot(list(Counter(np.hstack([HC_e[0,:], HC_e[1,:]])).values()), kde=False, label="Heuristic", bins=np.linspace(0, 25))
ax = sns.distplot(list(Counter(np.hstack([AE_e[0,:], AE_e[1,:]])).values()), kde=False, label="Embedding", bins=np.linspace(0, 25))
# ax = sns.distplot(list(Counter(np.hstack([AAE_e[0,:], AAE_e[1,:]])).values()), kde=False, label="Augmented Embedding", bins=np.linspace(0, 25))
ax.set(ylabel='Count')
ax.legend()
ax.set_title("Distribution of Edges per Hit")
plt.show()

Clearly, the two graphs have different topologies. The heuristic graph is much more "small-world" than the embedded graph: messages can be passed around much more freely.
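The edges-per-hit quantity plotted above is just the node degree; on a synthetic edge list it can be computed with a bincount:

```python
import numpy as np

# 5 hits in a simple chain: 4 edges
e = np.array([[0, 1, 2, 3],
              [1, 2, 3, 4]])
n_hits = 5
degree = np.bincount(np.hstack([e[0, :], e[1, :]]), minlength=n_hits)
print(degree)            # [1 2 2 2 1]
print(degree.mean())     # 1.6 == 2 * n_edges / n_hits
```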

A different choice of inter-graph connection

Let's try constructing the track by keeping all edges that cross between graph sections.

In [5]:
def construct_AE_graph(X, e, y, I, pid, norm_phi_min, delta, n_phi_sections, augmented):
    """Construct one graph (e.g. from one event).

    Masks out the edges and hits falling within [phi_min, phi_min + delta_phi].
    """
    
    # Mask out phi segment edges and truths, leaving full hit list
    seg_hits = (X[:,1] >= norm_phi_min) & (X[:,1] < (norm_phi_min + delta))
    seg_edges = seg_hits[e[1,:]] | seg_hits[e[0,:]] # Whether to filter by in or out end may impact training, explore this further!
    e_sec = e[:, seg_edges]
    y_sec = y[seg_edges]
    
    # Prepare the features data types    
    y_sec = y_sec.astype(np.float32)
    
    # Option to augment graph with random connections
    if augmented:
        e_sec, y_sec = augment_graph(e_sec, y_sec)
    
    # Reset hit indices to avoid unused hits
    X, e_sec, I, pid = reset_hit_indices(X, e_sec, I, pid)

    # Center phi at 0
    X[:,1] = X[:,1] - norm_phi_min - (delta/2)
    # Handle case of crossing the boundary
    X[X[:,1] < (-n_phi_sections), 1] += 2*n_phi_sections
    X[X[:,1] > (n_phi_sections), 1] -= 2*n_phi_sections
    
    
    graph = Data(x=torch.from_numpy(X).float(),
                 edge_index=torch.from_numpy(e_sec),
                 y=torch.from_numpy(y_sec),
                 I=torch.from_numpy(I),
                 pid=torch.from_numpy(pid))
    
    # Return a tuple of the results
    return graph
In [8]:
%%time
train_size, test_size = 1, 1
out = display(progress(0, train_size), display_id=True)
AE_train_dataset = [prepare_AE_event(event, feature_names, feature_scale, pt_min, n_phi_sections, iter, train_size, out) for (iter, event) in enumerate(all_events[:train_size])]
AE_train_dataset = [datapoint for dataset in AE_train_dataset for datapoint in dataset]
out = display(progress(0, test_size), display_id=True)
AE_test_dataset = [prepare_AE_event(event, feature_names, feature_scale, pt_min, n_phi_sections, iter, test_size, out) for (iter, event) in enumerate(all_events[-test_size:])]
AE_test_dataset = [datapoint for dataset in AE_test_dataset for datapoint in dataset]
AE_train_loader = DataLoader(AE_train_dataset, batch_size=2, shuffle=True)
AE_test_loader = DataLoader(AE_test_dataset, batch_size=2, shuffle=True)
CPU times: user 1.42 s, sys: 423 ms, total: 1.84 s
Wall time: 1.47 s
In [10]:
g1 = AE_test_loader.dataset[7]
g2 = AE_test_loader.dataset[6]
X = g1.x.numpy() * feature_scale
hits = np.vstack([X.T, g1.I.numpy(), g1.pid.numpy()]).T
sum(hits[:,4] == brian)
Out[10]:
10

Finding Brian

In [11]:
np.where(np.unique(hits[:,4])==brian)
Out[11]:
(array([1465]),)
In [12]:
plt.figure(figsize=(10,10))
for pid in np.unique(hits[:,4])[1460:1470]:
    pid_hits = hits[hits[:,4] == pid]
    x = pid_hits[:,0] * np.cos(pid_hits[:,1])
    y = pid_hits[:,0] * np.sin(pid_hits[:,1])
    plt.scatter(x, y, s=50)

As before, our friend is the brown track, and again the entire track is contained in this section - with this split choice the boundary-crossing edges are kept. A visualisation will help:

In [13]:
%%time
plot_brian(g1, feature_scale, brian)
[[6178 5623 4788 5619 4004 6438 5623 5619 6178 6442  439 6178 4788 4788]
 [2580 6438 5623 6442 4788  439 6442 6438 6178  439 6178 2579 5619 4788]]
CPU times: user 8.22 s, sys: 185 ms, total: 8.41 s
Wall time: 8.41 s

And now, Brian is fully covered by this graph. We got a little lucky, in that he fits within a single one-eighth phi section, given how much curvature the track has.

In [26]:
%%time
plot_brian(g2, feature_scale, brian)
[[1883 1883]
 [2696 2695]]
CPU times: user 9.57 s, sys: 220 ms, total: 9.79 s
Wall time: 9.79 s

Edge Distributions

In [39]:
HC_data = HC_test_loader.dataset[0]
HC_e, HC_X = HC_data['edge_index'].numpy(), HC_data['x'].numpy()*feature_scale
In [40]:
AE_data = AE_test_loader.dataset[0]
AE_e, AE_X = AE_data['edge_index'].numpy(), AE_data['x'].numpy()*feature_scale
In [41]:
plt.figure(figsize=(10,10))
ax = sns.distplot(list(Counter(np.hstack([HC_e[0,:], HC_e[1,:]])).values()), kde=False, label="Heuristic", bins=np.linspace(0, 25))
ax = sns.distplot(list(Counter(np.hstack([AE_e[0,:], AE_e[1,:]])).values()), kde=False, label="Embedding", bins=np.linspace(0, 25))
# ax = sns.distplot(list(Counter(np.hstack([AAE_e[0,:], AAE_e[1,:]])).values()), kde=False, label="Augmented Embedding", bins=np.linspace(0, 25))
ax.set(ylabel='Count')
ax.legend()
ax.set_title("Distribution of Edges per Hit")
plt.show()

Note that with this choice of split, the edge distribution is shifted to the right. That's good - more edges per hit let the GNN pass messages better.

Classification

Doublet Classification of Embedded Graph

In [3]:
# Load by directory (preferred)
result_base = os.path.expandvars('$SCRATCH/ExaTrkX/processed_sparse/results/')
result_name = 'high_003'
result_dir = os.path.join(result_base, result_name)

config = load_config_dir(result_dir)
print('Configuration:')
pprint.pprint(config)

summaries = load_summaries(config)
best_idx = summaries.valid_loss.idxmin()
print('\nTraining summaries:')
summaries
Configuration:
{'data': {'batch_size': 4,
          'input_dir': '${SCRATCH}/ExaTrkX/processed_sparse/adjacent/doublets/high_fullsplit/',
          'n_train': 56000,
          'n_valid': 1600,
          'n_workers': 4,
          'name': 'hitgraphs_sparse',
          'real_weight': 2.5},
 'model': {'hidden_activation': 'Tanh',
           'hidden_dim': 64,
           'input_dim': 3,
           'layer_norm': True,
           'loss_func': 'binary_cross_entropy_with_logits',
           'n_graph_iters': 8,
           'name': 'resgnn'},
 'n_ranks': 8,
 'optimizer': {'learning_rate': 0.001, 'name': 'Adam', 'weight_decay': 0.0001},
 'output_dir': '/global/cscratch1/sd/danieltm/ExaTrkX/processed_sparse/results/high_003',
 'trainer': {'name': 'gnn_sparse'},
 'training': {'n_total_epochs': 90}}

Training summaries:
Out[3]:
lr train_loss l1 l2 epoch train_time valid_loss valid_acc valid_time
0 0.00100 0.480320 2724.761733 25.529860 0 368.761703 0.213951 0.908253 5.329962
1 0.00100 0.287234 2756.561342 25.511140 1 365.699270 0.166751 0.930555 4.984706
2 0.00100 0.255982 2792.886014 25.511949 2 365.568742 0.184168 0.922824 4.849235
3 0.00100 0.227256 2817.883548 25.491329 3 365.997961 0.131679 0.945866 4.852307
4 0.00100 0.267749 2857.947180 25.503524 4 365.732243 0.139397 0.943319 5.220142
5 0.00100 0.208958 2881.313417 25.480691 5 365.553983 0.120133 0.950475 4.833835
6 0.00100 0.213459 2913.146642 25.503262 6 365.107543 0.124390 0.949635 4.888471
7 0.00100 0.210994 2945.498143 25.523710 7 366.041092 0.119422 0.951258 4.848629
8 0.00100 0.198235 2971.334172 25.539901 8 365.688171 0.106836 0.956552 4.818961
9 0.00100 0.202822 2992.512530 25.517580 9 365.631917 0.109237 0.955272 4.776529
10 0.00100 0.185641 3011.413442 25.499880 10 365.747768 0.103315 0.958062 4.760639
11 0.00100 0.187603 3032.891442 25.493363 11 366.035794 0.106440 0.956849 4.771806
12 0.00100 0.191505 3050.932564 25.464720 12 366.271327 0.104330 0.957649 4.762846
13 0.00100 0.175734 3065.927527 25.444207 13 364.957310 0.146210 0.939502 4.769556
14 0.00100 0.222537 3091.782410 25.505983 14 365.262520 0.113833 0.954556 4.766999
15 0.00100 0.193426 3099.076847 25.444813 15 366.014827 0.123747 0.950101 4.915367
16 0.00100 0.177783 3099.421099 25.368681 16 366.268671 0.110611 0.955038 4.771200
17 0.00100 0.158881 3093.973594 25.266758 17 366.777417 0.090036 0.963953 4.803784
18 0.00100 0.167444 3110.203478 25.269157 18 365.177988 0.093303 0.961949 4.774706
19 0.00100 0.163159 3107.448481 25.195310 19 365.919331 0.100099 0.959381 4.787841
20 0.00100 0.150567 3093.851626 25.082908 20 366.797141 0.090313 0.963345 4.769958
21 0.00100 0.181522 3109.879829 25.115819 21 365.447225 0.088713 0.964599 4.783289
22 0.00100 0.160222 3114.496395 25.069261 22 366.344191 0.090630 0.963576 4.796783
23 0.00100 0.153649 3123.283374 25.062643 23 365.393522 0.089410 0.964002 4.773188
24 0.00100 0.140369 3107.736462 24.926412 24 365.437787 0.082928 0.966686 4.774567
25 0.00100 0.143669 3111.288079 24.891662 25 365.293435 0.081446 0.967578 4.769677
26 0.00100 0.142932 3115.014524 24.869682 26 366.895594 0.084073 0.966650 4.804823
27 0.00100 0.136113 3106.652848 24.787365 27 365.480295 0.080974 0.967713 4.791939
28 0.00100 0.153690 3127.738741 24.844954 28 365.463422 0.124039 0.949855 4.805463
29 0.00100 0.147110 3131.148603 24.843886 29 366.448647 0.081836 0.967400 5.553495
... ... ... ... ... ... ... ... ... ...
60 0.00033 0.095532 2915.934687 23.250625 60 366.000777 0.064084 0.974575 4.795584
61 0.00033 0.094783 2897.436367 23.147550 61 364.935869 0.064359 0.974618 4.808300
62 0.00033 0.094764 2880.410154 23.053366 62 365.761077 0.064131 0.974550 4.824701
63 0.00033 0.094569 2865.207873 22.967907 63 365.429134 0.063808 0.974708 4.795633
64 0.00033 0.094180 2850.607115 22.884840 64 366.307959 0.065540 0.973995 4.883636
65 0.00033 0.094027 2827.471299 22.751862 65 373.210591 0.065064 0.974345 4.795253
66 0.00033 0.093769 2807.311778 22.631057 66 367.854374 0.062123 0.975624 4.810517
67 0.00033 0.093566 2789.466365 22.523850 67 365.895271 0.063452 0.975075 4.791810
68 0.00033 0.105724 2784.652259 22.470618 68 365.378692 0.075112 0.970121 4.833318
69 0.00033 0.094712 2767.970880 22.391676 69 366.127369 0.062879 0.975145 4.794605
70 0.00033 0.094163 2754.577156 22.313336 70 366.059221 0.063067 0.975151 4.794578
71 0.00033 0.092679 2740.316425 22.228727 71 365.503531 0.062416 0.975297 4.824136
72 0.00033 0.092386 2727.083477 22.148959 72 365.562827 0.062609 0.975333 4.829486
73 0.00033 0.098390 2722.006754 22.115117 73 365.577044 0.064350 0.974605 4.793537
74 0.00033 0.091829 2709.728173 22.043364 74 365.613181 0.060843 0.976183 4.813022
75 0.00033 0.091956 2697.864056 21.971483 75 365.718332 0.062496 0.975291 4.810519
76 0.00033 0.102415 2697.067702 21.949026 76 365.452330 0.064937 0.974404 4.809252
77 0.00033 0.094726 2687.966272 21.892341 77 401.582917 0.071561 0.972003 5.456798
78 0.00033 0.092350 2675.621553 21.829146 78 398.595022 0.062841 0.975237 5.388688
79 0.00033 0.095087 2665.496492 21.769134 79 397.906691 0.062010 0.975467 5.375511
80 0.00033 0.098534 2659.300210 21.727648 80 397.729241 0.059998 0.976370 5.362723
81 0.00033 0.096194 2650.890166 21.676059 81 397.215463 0.061306 0.975876 5.404302
82 0.00033 0.099000 2646.589094 21.647228 82 397.799167 0.061109 0.975947 5.351732
83 0.00033 0.089474 2633.469779 21.568967 83 397.395612 0.062251 0.975580 5.368637
84 0.00033 0.106998 2631.495309 21.545330 84 398.952156 0.061966 0.975416 6.060371
85 0.00033 0.106040 2631.161827 21.535411 85 396.832666 0.072841 0.971389 5.382160
86 0.00033 0.093107 2622.532890 21.489563 86 397.995421 0.060113 0.976390 5.399356
87 0.00033 0.091341 2615.749336 21.438266 87 398.841028 0.062724 0.975220 5.382360
88 0.00033 0.097190 2615.351851 21.412423 88 396.774070 0.085683 0.965734 5.362917
89 0.00033 0.091244 2605.896000 21.363345 89 398.031071 0.062929 0.975244 5.364826

90 rows × 9 columns

In [4]:
# Build the trainer and load best checkpoint
trainer = get_trainer(output_dir=config['output_dir'], gpu=0, **config['trainer'])
trainer.build_model(optimizer_config=config['optimizer'], **config['model'])

best_epoch = summaries.epoch.loc[best_idx]
trainer.load_checkpoint(checkpoint_id=best_epoch)

print(trainer.model)
print('Parameters:', sum(p.numel() for p in trainer.model.parameters()))
GNNSegmentClassifier(
  (input_network): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): Tanh()
  )
  (edge_network): EdgeNetwork(
    (network): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): Tanh()
      (3): Linear(in_features=64, out_features=64, bias=True)
      (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (5): Tanh()
      (6): Linear(in_features=64, out_features=64, bias=True)
      (7): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (8): Tanh()
      (9): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (node_network): NodeNetwork(
    (network): Sequential(
      (0): Linear(in_features=192, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): Tanh()
      (3): Linear(in_features=64, out_features=64, bias=True)
      (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (5): Tanh()
      (6): Linear(in_features=64, out_features=64, bias=True)
      (7): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (8): Tanh()
      (9): Linear(in_features=64, out_features=64, bias=True)
      (10): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (11): Tanh()
    )
  )
)
Parameters: 42753
In [5]:
# Load the test dataset
n_test = 80
test_loader, filelist = get_test_data_loader(config, n_test=n_test)
In [6]:
%%time
# Apply the model
test_preds, test_targets = trainer.device_predict(test_loader)
CPU times: user 5.9 s, sys: 6.25 s, total: 12.1 s
Wall time: 6.3 s
In [9]:
i = 0
g = test_loader.dataset[i]
pid = np.load(filelist[g.i][:-4]+"_ID.npz", allow_pickle=True)["pid"]
X, e, y, o = g.x.numpy()*feature_scale, g.edge_index.numpy(), g.y.numpy(), test_preds[i].numpy()
In [19]:
filelist[g.i]
Out[19]:
'/global/cscratch1/sd/danieltm/ExaTrkX/processed_sparse/adjacent/doublets/high_fullsplit/event9999_7.npz'
In [60]:
def draw_sample_brian_xy(hits, edges, preds, brian, pid, labels, cut=0.5, figsize=(16, 16)):
    """Draw all doublet edges in the x-y plane and highlight the edges of
    particle `brian`: black for above-cut predictions, red for below-cut."""
    x = hits[:,0] * np.cos(hits[:,1])
    y = hits[:,0] * np.sin(hits[:,1])
    fig, ax0 = plt.subplots(figsize=figsize)
    brian_filter = pid == brian
    p_brian_edges = edges[:,(brian_filter[edges[0,:]]) & (brian_filter[edges[1,:]]) & (preds > cut)]
    n_brian_edges = edges[:,(brian_filter[edges[0,:]]) & (brian_filter[edges[1,:]]) & (preds < cut)]
    
    # Draw the hits
    ax0.scatter(x, y, s=2, c='k', alpha=0.1)

    # Draw the segments
    for j in range(labels.shape[0]):

        # False negatives
        if preds[j] < cut and labels[j] > cut:
            ax0.plot([x[edges[0,j]], x[edges[1,j]]],
                     [y[edges[0,j]], y[edges[1,j]]],
                     '--', c='b', alpha=0.1)

        # False positives
        if preds[j] > cut and labels[j] < cut:
            ax0.plot([x[edges[0,j]], x[edges[1,j]]],
                     [y[edges[0,j]], y[edges[1,j]]],
                     '-', c='r', alpha=0.1)

        # True positives
        if preds[j] > cut and labels[j] > cut:
            ax0.plot([x[edges[0,j]], x[edges[1,j]]],
                     [y[edges[0,j]], y[edges[1,j]]],
                     '-', c='k', alpha=0.1)
   
    ax0.plot([x[p_brian_edges[0,:]], x[p_brian_edges[1,:]]], [y[p_brian_edges[0,:]], y[p_brian_edges[1,:]]], c='k', linewidth=3, alpha=0.9)
    ax0.plot([x[n_brian_edges[0,:]], x[n_brian_edges[1,:]]], [y[n_brian_edges[0,:]], y[n_brian_edges[1,:]]], c='r', linewidth=3, alpha=0.9)
    ax0.scatter(x[brian_filter], y[brian_filter], c='k', s=50, alpha=0.9)
            
    return fig, ax0
In [51]:
%%time
draw_sample_xy(X, e, o, y)
CPU times: user 13.9 s, sys: 124 ms, total: 14 s
Wall time: 14 s
Out[51]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aab6c3d2748>)
In [12]:
np.where(pid[e[0,o>0.5]]==brian)
Out[12]:
(array([2023, 2719, 2721, 3375, 3376, 3380, 3381, 4036, 4037, 4616, 5039,
        5040]),)
In [13]:
np.where(pid[e[1,o>0.5]]==brian)
Out[13]:
(array([2023, 2719, 2721, 3375, 3376, 3380, 3381, 4036, 4037, 4616, 5039,
        5040]),)
In [14]:
e_pid = pid[e[0,:]] * y
In [15]:
cut = 0.5
p_edges = e[:,o>cut]
n_edges = e[:,o<cut]
brian_p_edges = p_edges[:,e_pid[o>cut]==brian]
brian_n_edges = n_edges[:,e_pid[o<cut]==brian]
full_brian = np.hstack([brian_n_edges,brian_p_edges])
In [16]:
brian_p_edges
Out[16]:
array([[4005, 4786, 4786, 5617, 5621, 5617, 5621, 6433, 6437,  433, 1688,
        1688],
       [4786, 5617, 5621, 6433, 6433, 6437, 6437,  433,  433, 1688, 2581,
        2582]])
In [17]:
brian_n_edges
Out[17]:
array([], shape=(2, 0), dtype=int64)
In [18]:
plt.figure(figsize=(20,10))
x = X[:,0] * np.cos(X[:,1])
y = X[:,0] * np.sin(X[:,1])
for edge in full_brian.T[:]:
    plt.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], linewidth=1, alpha=0.9)
    plt.scatter(x[edge], y[edge])
In [27]:
plt.figure(figsize=(20,10))
x = X[:,0] * np.cos(X[:,1])
y = X[:,0] * np.sin(X[:,1])
for edge in brian_p_edges.T[:]:
    plt.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], linewidth=1, alpha=0.9)
    plt.scatter(x[edge], y[edge])
# plt.plot([x[n_brian_edges[0,:]], x[n_brian_edges[1,:]]], [y[n_brian_edges[0,:]], y[n_brian_edges[1,:]]], c='r', linewidth=2, alpha=0.9)
In [28]:
plt.figure(figsize=(20,10))
x = X[:,0] * np.cos(X[:,1])
y = X[:,0] * np.sin(X[:,1])
for edge in brian_n_edges.T[:]:
    plt.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], linewidth=1, alpha=0.9)
    plt.scatter(x[edge], y[edge])
In [32]:
plt.figure(figsize=(20,10))
x = X[:,0] * np.cos(X[:,1])
y = X[:,0] * np.sin(X[:,1])
plt.plot([x[brian_p_edges[0,:]], x[brian_p_edges[1,:]]], [y[brian_p_edges[0,:]], y[brian_p_edges[1,:]]], c='k', linewidth=2, alpha=0.9)
plt.plot([x[brian_n_edges[0,:]], x[brian_n_edges[1,:]]], [y[brian_n_edges[0,:]], y[brian_n_edges[1,:]]], c='r', linewidth=2, alpha=0.9)
Out[32]:
[<matplotlib.lines.Line2D at 0x2aab3ef1db70>]
In [61]:
%%time
draw_sample_brian_xy(X, e, o, brian, pid, y, cut=0.5)
CPU times: user 14.7 s, sys: 160 ms, total: 14.9 s
Wall time: 14.9 s
Out[61]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aabcdc328d0>)

The takeaway: Brian has been classified quite successfully. To be clear, he was not chosen as a track that would be easy to classify; he was picked at random, a priori, as a difficult track, without knowing how well his classification would turn out. He is, however, still missing one segment, so we next attempt to classify him with a triplet representation.
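A triplet representation promotes each doublet (an edge of the hit graph) to a node, and connects two doublets whenever the outer hit of one is the inner hit of the other, i.e. they chain into a three-hit segment. A minimal sketch of this promotion (a hypothetical helper, not the pipeline's actual graph builder):

```python
import numpy as np

def doublets_to_triplets(edge_index):
    """Promote doublets (edges of the hit graph) to nodes of a triplet
    graph: doublets a and b are connected when a's outer hit is b's
    inner hit, so together they span three hits."""
    starts, ends = edge_index
    tri_edges = [(a, b)
                 for a in range(edge_index.shape[1])
                 for b in np.where(starts == ends[a])[0]]
    return np.array(tri_edges, dtype=np.int64).reshape(-1, 2).T

# Toy hit graph: hits 0 -> 1 -> 2, plus a stray edge 0 -> 3
e = np.array([[0, 1, 0],
              [1, 2, 3]])
print(doublets_to_triplets(e))  # [[0], [1]]: doublet 0->1 chains into 1->2
```

Only the consecutive pair survives as a triplet edge; the stray doublet 0->3 has no continuation, which is exactly what lets the triplet classifier suppress combinatorial fakes.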

Some different GNN architectures


Triplet Classification of Embedded Graph

We perform a very similar analysis to the doublet case, but use the doublet scores to train the triplet scores.
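Given the configured `input_dim: 7`, each triplet-graph node plausibly carries the cylindrical coordinates (r, phi, z) of its doublet's two hits plus the doublet GNN score — an assumption, but one consistent with the `np.hstack([feature_scale, feature_scale, 1])` rescaling applied when the graphs are read back below. A hypothetical sketch of that feature layout:

```python
import numpy as np

def build_triplet_features(hits_rphiz, edge_index, doublet_scores):
    """Assumed 7-feature layout for a triplet-graph node (a doublet):
    (r, phi, z) of the inner hit, (r, phi, z) of the outer hit, and the
    doublet classifier's score."""
    inner = hits_rphiz[edge_index[0]]   # (n_doublets, 3)
    outer = hits_rphiz[edge_index[1]]   # (n_doublets, 3)
    return np.hstack([inner, outer, doublet_scores[:, None]])

hits = np.array([[ 32.0, 0.10,  -5.0],
                 [ 72.0, 0.12, -11.0],
                 [116.0, 0.14, -18.0]])
e = np.array([[0, 1],
              [1, 2]])
scores = np.array([0.97, 0.91])
print(build_triplet_features(hits, e, scores).shape)  # (2, 7)
```

This also explains the plotting code later on: columns 0-1 give the inner hit's (r, phi) and columns 3-4 the outer hit's, so each triplet node is drawn as two segments.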

In [5]:
# Load by directory (preferred)
result_base = os.path.expandvars('$SCRATCH/ExaTrkX/processed_sparse/results/triplets')
result_name = 'high_lowcut_005'
result_dir = os.path.join(result_base, result_name)

config = load_config_dir(result_dir)
print('Configuration:')
pprint.pprint(config)

summaries = load_summaries(config)
best_idx = summaries.valid_loss.idxmin()
print('\nTraining summaries:')
summaries
Configuration:
{'data': {'batch_size': 2,
          'input_dir': '${SCRATCH}/ExaTrkX/processed_sparse/adjacent/triplets/high_lowcut_003/',
          'n_train': 56000,
          'n_valid': 1600,
          'n_workers': 4,
          'name': 'hitgraphs_sparse',
          'real_weight': 2},
 'model': {'hidden_activation': 'Tanh',
           'hidden_dim': 128,
           'input_dim': 7,
           'layer_norm': True,
           'loss_func': 'binary_cross_entropy_with_logits',
           'n_graph_iters': 4,
           'name': 'resgnn'},
 'n_ranks': 8,
 'optimizer': {'learning_rate': 0.0005, 'name': 'Adam', 'weight_decay': 0.0001},
 'output_dir': '/global/cscratch1/sd/danieltm/ExaTrkX/processed_sparse/results/triplets/high_lowcut_005',
 'project': 'Embedded-built-testing',
 'trainer': {'name': 'gnn_sparse'},
 'training': {'n_total_epochs': 90}}

Training summaries:
Out[5]:
lr train_loss l1 l2 epoch train_time valid_loss valid_acc valid_time
0 0.000500 0.073155 5198.984462 23.492924 0 336.094266 0.053852 0.982889 3.754843
1 0.000500 0.061457 3673.209646 18.289839 1 334.038802 0.035929 0.987122 3.716984
2 0.000500 0.044579 3031.657952 16.225513 2 333.802999 0.028767 0.990040 3.709755
3 0.000500 0.037699 2665.391043 15.052534 3 333.993042 0.027278 0.990536 3.699720
4 0.000500 0.038018 2546.416082 14.468287 4 333.856523 0.027039 0.990984 3.713910
5 0.000500 0.033454 2233.187310 13.848771 5 334.360373 0.022666 0.992246 3.754658
6 0.000500 0.027984 1971.920580 13.248920 6 333.691531 0.021675 0.992530 4.278229
7 0.000500 0.026802 1763.009086 12.760698 7 334.101805 0.019268 0.993434 3.713725
8 0.000500 0.025718 1621.243451 12.376881 8 334.205973 0.018878 0.993589 3.698469
9 0.000500 0.024842 1504.260715 12.057010 9 334.365024 0.019006 0.993660 3.712008
10 0.000500 0.024704 1443.613328 11.809676 10 333.173209 0.025395 0.992167 3.697717
11 0.000500 0.024270 1400.301292 11.613951 11 334.263838 0.019261 0.993390 3.731421
12 0.000500 0.023608 1303.193156 11.377265 12 334.911291 0.018002 0.993802 3.703226
13 0.000500 0.023770 1278.518140 11.232848 13 333.855265 0.018015 0.993842 3.713850
14 0.000500 0.023327 1242.392880 11.074830 14 334.333294 0.017614 0.993950 3.721895
15 0.000500 0.022572 1178.716929 10.898969 15 334.202676 0.017445 0.994072 3.865875
16 0.000500 0.022634 1145.573797 10.770276 16 334.634353 0.017842 0.994089 4.198804
17 0.000500 0.022150 1108.661814 10.637860 17 334.986716 0.017853 0.994009 3.709605
18 0.000500 0.021957 1080.883208 10.521448 18 334.150733 0.017206 0.994165 3.710685
19 0.000500 0.021777 1056.529619 10.414899 19 334.659640 0.017047 0.994170 3.709027
20 0.000500 0.021716 1042.085304 10.324113 20 334.827372 0.019073 0.993489 3.719098
21 0.000500 0.021487 1025.919497 10.233804 21 334.490094 0.016670 0.994455 4.495259
22 0.000500 0.021345 1008.528245 10.147825 22 334.477725 0.015852 0.994638 3.743489
23 0.000500 0.021297 980.689998 10.075415 23 334.255198 0.016388 0.994412 3.718493
24 0.000500 0.021002 972.492742 10.004570 24 334.549976 0.016582 0.994381 3.711129
25 0.000500 0.021119 965.907075 9.951418 25 335.191045 0.017239 0.994250 3.713901
26 0.000500 0.020684 931.397900 9.878254 26 335.104271 0.015838 0.994666 3.705183
27 0.000500 0.020808 952.513400 9.844192 27 333.960695 0.016462 0.994412 3.703416
28 0.000500 0.023269 1331.069203 10.138618 28 334.113100 0.028054 0.990402 3.711659
29 0.000500 0.023400 993.566760 9.875866 29 334.190578 0.016630 0.994306 3.704496
... ... ... ... ... ... ... ... ... ...
60 0.000165 0.017023 907.322291 9.050642 60 334.294349 0.013150 0.995522 3.709404
61 0.000165 0.016848 903.001985 9.037481 61 334.283357 0.013278 0.995524 3.990748
62 0.000165 0.016824 895.935292 9.023551 62 335.250902 0.013568 0.995437 3.746364
63 0.000165 0.016799 891.925294 9.010123 63 335.710876 0.013073 0.995623 3.718560
64 0.000165 0.016653 893.119295 8.997444 64 334.331829 0.012960 0.995563 3.728253
65 0.000165 0.016624 891.952536 8.984923 65 334.090067 0.014043 0.995332 3.740708
66 0.000165 0.016625 887.446463 8.971713 66 335.403013 0.012680 0.995697 3.753984
67 0.000165 0.016622 896.694958 8.961208 67 336.439976 0.013154 0.995502 3.710060
68 0.000165 0.016529 879.180894 8.946996 68 334.739685 0.013540 0.995414 4.462043
69 0.000165 0.016516 885.668402 8.934668 69 334.013649 0.012692 0.995706 3.726285
70 0.000165 0.016449 878.900937 8.922248 70 334.674348 0.012934 0.995642 3.731372
71 0.000165 0.016399 875.382774 8.909185 71 334.894490 0.013409 0.995433 3.773468
72 0.000165 0.016331 875.614434 8.895917 72 334.575586 0.012926 0.995602 3.734381
73 0.000165 0.016386 874.979403 8.882985 73 333.800678 0.012823 0.995671 3.724993
74 0.000165 0.016279 871.203912 8.871495 74 334.319071 0.013369 0.995440 4.053293
75 0.000165 0.016199 864.120380 8.858350 75 334.016477 0.013274 0.995529 3.797201
76 0.000165 0.016235 871.410083 8.847809 76 333.769648 0.013416 0.995543 3.719641
77 0.000165 0.016076 864.074029 8.834390 77 333.731056 0.012576 0.995761 3.714146
78 0.000165 0.016139 856.752898 8.822114 78 334.745253 0.012639 0.995741 3.740770
79 0.000165 0.016183 855.199186 8.809951 79 334.809445 0.012656 0.995739 3.728658
80 0.000165 0.016063 860.615965 8.799899 80 335.512459 0.012751 0.995696 3.748494
81 0.000165 0.016025 854.298714 8.787507 81 333.710150 0.012958 0.995586 3.732986
82 0.000165 0.015945 854.030358 8.777333 82 334.447678 0.012287 0.995859 3.747921
83 0.000165 0.015980 848.153074 8.766873 83 334.245091 0.012488 0.995814 4.021411
84 0.000165 0.015724 847.909295 8.730907 84 350.146532 0.012763 0.995655 4.340663
85 0.000165 0.015645 833.991985 8.699670 85 334.624476 0.012092 0.995896 3.712860
86 0.000165 0.015614 832.306602 8.674569 86 334.192294 0.012626 0.995745 3.693188
87 0.000165 0.015540 819.720884 8.649986 87 335.367037 0.011737 0.996046 3.682312
88 0.000165 0.015551 820.074120 8.628566 88 332.963703 0.012188 0.995879 3.696671
89 0.000165 0.015462 819.549706 8.608182 89 334.192140 0.012069 0.995887 3.703552

90 rows × 9 columns

In [6]:
# Build the trainer and load best checkpoint
trainer = get_trainer(output_dir=config['output_dir'], gpu=0, **config['trainer'])
trainer.build_model(optimizer_config=config['optimizer'], **config['model'])

best_epoch = summaries.epoch.loc[best_idx]
trainer.load_checkpoint(checkpoint_id=best_epoch)

print(trainer.model)
print('Parameters:', sum(p.numel() for p in trainer.model.parameters()))
GNNSegmentClassifier(
  (input_network): Sequential(
    (0): Linear(in_features=7, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): Tanh()
  )
  (edge_network): EdgeNetwork(
    (network): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): Tanh()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): Tanh()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (8): Tanh()
      (9): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (node_network): NodeNetwork(
    (network): Sequential(
      (0): Linear(in_features=384, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): Tanh()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): Tanh()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (8): Tanh()
      (9): Linear(in_features=128, out_features=128, bias=True)
      (10): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (11): Tanh()
    )
  )
)
Parameters: 167937
In [7]:
# Load the test dataset
n_test = 80
test_loader, filelist = get_test_data_loader(config, n_test=n_test)
In [8]:
%%time
# Apply the model
test_preds, test_targets = trainer.device_predict(test_loader)
CPU times: user 4.4 s, sys: 5.16 s, total: 9.56 s
Wall time: 5.42 s
In [45]:
i = 0
g = test_loader.dataset[i]
pid = np.load(filelist[g.i][:-4]+"_ID.npz", allow_pickle=True)["pid"]
X, e, y, o = g.x.numpy()*np.hstack([feature_scale, feature_scale, 1]), g.edge_index.numpy(), g.y.numpy(), test_preds[i].numpy()
In [13]:
len(pid.nonzero()[0])
Out[13]:
5260

Some Multiplicities

In [14]:
%%time
draw_triplets_tf_mul_xy(X, e, o, y)
CPU times: user 9.87 s, sys: 39.8 ms, total: 9.91 s
Wall time: 9.92 s
Out[14]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aaaae492400>)
In [15]:
%%time
draw_triplets_xy_antiscore_cut_edges(X, e, o, y)
CPU times: user 19.7 s, sys: 175 ms, total: 19.8 s
Wall time: 19.8 s
Out[15]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aab8dd3dd68>)
In [18]:
%%time
draw_triplets_xy_antiscore_cut(X, e, o, y)
Overperforms by:  595 , underperforms by:  41 .
CPU times: user 1.3 s, sys: 3.51 ms, total: 1.3 s
Wall time: 1.3 s
Out[18]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aabd90ac550>)
In [19]:
%%time
draw_triplets_xy(X, e, o, y)
CPU times: user 9.45 s, sys: 110 ms, total: 9.56 s
Wall time: 9.53 s
Out[19]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aabd9872cc0>)

Brian in Triplets

Brian should have 14 triplet components.
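Assuming each triplet chains two doublets through a shared middle hit, this count can be checked directly against Brian's 12 above-threshold doublet edges from the doublet section:

```python
import numpy as np

# Brian's above-threshold doublet edges, copied from the doublet section
doublets = np.array(
    [[4005, 4786, 4786, 5617, 5621, 5617, 5621, 6433, 6437,  433, 1688, 1688],
     [4786, 5617, 5621, 6433, 6433, 6437, 6437,  433,  433, 1688, 2581, 2582]])

# Each pair of doublets sharing a middle hit forms one triplet
starts, ends = doublets
n_triplets = int(sum(np.sum(starts == end) for end in ends))
print(n_triplets)  # 14
```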

In [71]:
# Temporary pick of an arbitrary particle; superseded by the selection below
brian = pid[pid>0][1000]
In [104]:
cut = 0.5
brians = []
for b in np.unique(pid[pid > 0]):
    if len(np.where(pid[e[0, o > cut]] == b)[0]) > 8:
        brians.append(b)
In [105]:
len(brians)
Out[105]:
221
In [106]:
brian = brians[0]
In [21]:
np.where(pid[e[0,o>0.5]]==brian)
Out[21]:
(array([1785, 1786, 2506, 2507, 2508, 2509, 3144, 3145, 3152, 3153, 3681,
        3682, 4099, 4100]),)
In [22]:
np.where(pid[e[1,o>0.5]]==brian)
Out[22]:
(array([1785, 1786, 2506, 2507, 2508, 2509, 3144, 3145, 3152, 3153, 3681,
        3682, 4099, 4100]),)
In [23]:
np.where(pid[e[0,o<0.5]]==brian)
Out[23]:
(array([ 3098,  3099,  3100,  3101,  3102,  8016,  8017,  8018,  8019,
         8020,  8021,  8022,  8023,  8031,  8032,  8033,  8034,  8035,
         8036, 11598, 11599, 11600, 11601, 14119, 14120, 14121, 14122,
        14123, 15671, 15672, 15673, 15674]),)
In [24]:
np.where(pid[e[1,o<0.5]]==brian)
Out[24]:
(array([ 8013,  8029,  8039, 14127, 14128, 14134, 14135, 14141, 14142,
        14148, 14149, 14155, 14156, 14162, 14163, 14169, 14170, 14176,
        14177, 14183, 14184]),)
In [25]:
e_pid = pid[e[0,:]] * y
In [27]:
cut = 0.5
p_edges = e[:,o>cut]
n_edges = e[:,o<cut]
brian_p_edges = p_edges[:,e_pid[o>cut]==brian]
brian_n_edges = n_edges[:,e_pid[o<cut]==brian]
full_brian = np.hstack([brian_n_edges,brian_p_edges])
In [28]:
brian_p_edges
Out[28]:
array([[3456, 3456, 4448, 4448, 4450, 4450, 5815, 5816, 5823, 5824, 7450,
        7451, 8898, 8898],
       [4448, 4450, 5815, 5823, 5816, 5824, 7450, 7450, 7451, 7451, 8898,
        8898, 9931, 9932]])
In [29]:
brian_n_edges
Out[29]:
array([], shape=(2, 0), dtype=int64)
In [30]:
full_brian
Out[30]:
array([[3456, 3456, 4448, 4448, 4450, 4450, 5815, 5816, 5823, 5824, 7450,
        7451, 8898, 8898],
       [4448, 4450, 5815, 5823, 5816, 5824, 7450, 7450, 7451, 7451, 8898,
        8898, 9931, 9932]])
In [31]:
plt.figure(figsize=(20,10))
xi, yi = X[:,0] * np.cos(X[:,1]), X[:,0] * np.sin(X[:,1])
xo, yo = X[:,3] * np.cos(X[:,4]), X[:,3] * np.sin(X[:,4])
for edge in full_brian.T[:]:
    plt.plot([xi[edge[0]], xi[edge[1]]], [yi[edge[0]], yi[edge[1]]], linewidth=1, alpha=0.9)
    plt.plot([xo[edge[0]], xo[edge[1]]], [yo[edge[0]], yo[edge[1]]], linewidth=1, alpha=0.9)
    plt.scatter(xi[edge], yi[edge])
    plt.scatter(xo[edge], yo[edge])
In [35]:
plt.figure(figsize=(20,10))
xi, yi = X[:,0] * np.cos(X[:,1]), X[:,0] * np.sin(X[:,1])
xo, yo = X[:,3] * np.cos(X[:,4]), X[:,3] * np.sin(X[:,4])
for edge in brian_p_edges.T[:]:
    plt.plot([xi[edge[0]], xi[edge[1]]], [yi[edge[0]], yi[edge[1]]], linewidth=1, alpha=0.9)
    plt.plot([xo[edge[0]], xo[edge[1]]], [yo[edge[0]], yo[edge[1]]], linewidth=1, alpha=0.9)
    plt.scatter(xi[edge], yi[edge])
    plt.scatter(xo[edge], yo[edge])
In [33]:
plt.figure(figsize=(20,10))
x = X[:,0] * np.cos(X[:,1])
y = X[:,0] * np.sin(X[:,1])
for edge in brian_n_edges.T[:]:
    plt.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], linewidth=1, alpha=0.9)
    plt.scatter(x[edge], y[edge])
<Figure size 1440x720 with 0 Axes>
In [34]:
plt.figure(figsize=(20,10))
x = X[:,0] * np.cos(X[:,1])
y = X[:,0] * np.sin(X[:,1])
plt.plot([x[brian_p_edges[0,:]], x[brian_p_edges[1,:]]], [y[brian_p_edges[0,:]], y[brian_p_edges[1,:]]], c='k', linewidth=2, alpha=0.9)
plt.plot([x[brian_n_edges[0,:]], x[brian_n_edges[1,:]]], [y[brian_n_edges[0,:]], y[brian_n_edges[1,:]]], c='r', linewidth=2, alpha=0.9)
Out[34]:
[]
In [49]:
def draw_brian_triplets_xy(hits, edges, preds, brian, pid, labels, cut=0.5, figsize=(16, 16)):
    """Draw all triplet edges in the x-y plane (each node contributes both
    its inner and outer doublet segments) and highlight particle `brian`:
    black for above-cut predictions, red for below-cut."""
    xi, yi = [hits[:,0] * np.cos(hits[:,1]), hits[:,0] * np.sin(hits[:,1])]
    xo, yo = [hits[:,3] * np.cos(hits[:,4]), hits[:,3] * np.sin(hits[:,4])]
    fig, ax0 = plt.subplots(figsize=figsize)

    brian_filter = pid == brian
    p_brian_edges = edges[:,(brian_filter[edges[0,:]]) & (brian_filter[edges[1,:]]) & (preds > cut)]
    n_brian_edges = edges[:,(brian_filter[edges[0,:]]) & (brian_filter[edges[1,:]]) & (preds < cut)]
    
    # Draw the hits
    ax0.scatter(xi, yi, s=2, c='k')

    # Draw the segments
    for j in range(labels.shape[0]):

        # False negatives
        if preds[j] < cut and labels[j] > cut:
            ax0.plot([xi[edges[0,j]], xo[edges[0,j]]],
                     [yi[edges[0,j]], yo[edges[0,j]]],
                     '--', c='b', alpha=0.6)
            ax0.plot([xi[edges[1,j]], xo[edges[1,j]]],
                     [yi[edges[1,j]], yo[edges[1,j]]],
                     '--', c='b', alpha=0.6)

        # False positives
        if preds[j] > cut and labels[j] < cut:
            ax0.plot([xi[edges[0,j]], xo[edges[0,j]]],
                     [yi[edges[0,j]], yo[edges[0,j]]],
                     '-', c='r', alpha=0.6)
            ax0.plot([xi[edges[1,j]], xo[edges[1,j]]],
                     [yi[edges[1,j]], yo[edges[1,j]]],
                     '-', c='r', alpha=0.6)

        # True positives
        if preds[j] > cut and labels[j] > cut:
            ax0.plot([xi[edges[0,j]], xo[edges[0,j]]],
                     [yi[edges[0,j]], yo[edges[0,j]]],
                     '-', c='k', alpha=0.01)
            ax0.plot([xi[edges[1,j]], xo[edges[1,j]]],
                     [yi[edges[1,j]], yo[edges[1,j]]],
                     '-', c='k', alpha=0.01)

    ax0.plot([xi[p_brian_edges[0,:]], xi[p_brian_edges[1,:]]], [yi[p_brian_edges[0,:]], yi[p_brian_edges[1,:]]], c='k', linewidth=3, alpha=0.9)
    ax0.plot([xi[n_brian_edges[0,:]], xi[n_brian_edges[1,:]]], [yi[n_brian_edges[0,:]], yi[n_brian_edges[1,:]]], c='r', linewidth=3, alpha=0.9)
    ax0.plot([xo[p_brian_edges[0,:]], xo[p_brian_edges[1,:]]], [yo[p_brian_edges[0,:]], yo[p_brian_edges[1,:]]], c='k', linewidth=3, alpha=0.9)
    ax0.plot([xo[n_brian_edges[0,:]], xo[n_brian_edges[1,:]]], [yo[n_brian_edges[0,:]], yo[n_brian_edges[1,:]]], c='r', linewidth=3, alpha=0.9)
    ax0.scatter(xi[brian_filter], yi[brian_filter], c='k', s=50, alpha=0.9)
    ax0.scatter(xo[brian_filter], yo[brian_filter], c='k', s=50, alpha=0.9)
            
    return fig, ax0
In [38]:
e
Out[38]:
array([[    1,     1,     1, ..., 10874, 10875, 10876],
       [   45,    54,    61, ..., 10819, 10819, 10819]])
In [57]:
o > 0.5
Out[57]:
array([False, False, False, ..., False, False, False])
In [54]:
len(y)
Out[54]:
20830
In [55]:
%%time
n_edges = 20830
draw_brian_triplets_xy(X, e[:,:n_edges], o[:n_edges], brian, pid, y[:n_edges], cut=0.5)
CPU times: user 28 s, sys: 243 ms, total: 28.3 s
Wall time: 28.3 s
Out[55]:
(<Figure size 1152x1152 with 1 Axes>,
 <matplotlib.axes._subplots.AxesSubplot at 0x2aabfef4f0f0>)

Now we see that Brian (thick black) is perfectly classified, with every combination of his triplets included as an above-threshold prediction.
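To quantify this beyond a single track, the same masks generalize to per-event edge metrics. A small sketch (plain numpy, a hypothetical helper rather than part of the pipeline) computing efficiency and purity at the working point:

```python
import numpy as np

def edge_metrics(preds, labels, cut=0.5):
    """Efficiency (recall) and purity (precision) of edge classification
    at a given score threshold."""
    pred_pos = preds > cut
    true_pos = labels > cut
    tp = np.sum(pred_pos & true_pos)
    eff = tp / max(np.sum(true_pos), 1)  # fraction of true edges kept
    pur = tp / max(np.sum(pred_pos), 1)  # fraction of kept edges that are true
    return eff, pur

preds = np.array([0.9, 0.8, 0.2, 0.7, 0.1])
labels = np.array([1.0, 1.0, 1.0, 0.0, 0.0])
print(edge_metrics(preds, labels))  # both 2/3 on this toy input
```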
